
[Misc][DP] Fix AsyncLLM metrics for multi-API server deployments#6

Merged
njhill merged 14 commits into njhill:all-to-all from kouroshHakha:kh/fix-a2a-metrics
May 16, 2025

Conversation

@njhill
Owner

@njhill njhill commented May 13, 2025

No description provided.

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
Signed-off-by: kouroshhakha <kourosh@anyscale.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
Owner Author

@njhill njhill left a comment

Thanks for this @kouroshHakha.

I have a few general comments:

When there is a single api-server, we log metrics that come back from each engine separately, with the engine index as a label. In the multi-api-server PR I changed the logic on the engine side to only send its SchedulerStats back to one of the client api-servers. So hopefully, for the metrics corresponding to these, not much more should be needed apart from making sure the gauges use "mostrecent" mode.

However, other metrics are computed during the loop in async_llm.py based on the requests that were processed in that iteration (in IterationStats). With multiple API servers, the processing of these requests for a given engine will in general be distributed amongst the api-servers since the outputs are sent back to each based on which originally sent the request.

These we may have to look at more closely and on a case-by-case basis since for example some are histograms where we assume the count corresponds to the number of iterations that have run on the corresponding engine, and we'll now be recording multiple of these (could be between 1 and num api servers).
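To make the gauge-mode discussion concrete, here is a stdlib-only sketch of how prometheus multiprocess gauge modes combine per-process samples. This is an illustration of the semantics, not prometheus_client's actual implementation; `aggregate_gauge` and its sample tuples are invented for the example:

```python
# Illustrative only: mimics how prometheus_client's multiprocess gauge
# modes combine samples recorded by different processes for one gauge.
# Each sample is (pid, timestamp, value).

def aggregate_gauge(samples, mode):
    if mode == "mostrecent":
        # keep the value with the newest timestamp across processes
        return max(samples, key=lambda s: s[1])[2]
    if mode == "sum":
        # add up the values from every process
        return sum(s[2] for s in samples)
    if mode == "all":
        # keep one time series per process, labelled by pid
        return {s[0]: s[2] for s in samples}
    raise ValueError(f"unknown mode: {mode}")

# e.g. two api-server processes reporting the same engine's gauge
samples = [(101, 5.0, 4), (102, 6.0, 7)]
```

Under "mostrecent" the two api-servers' reports collapse to the latest value, under "sum" they add, and under "all" the pid leaks into the labels.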

From your PR description:

  1. Set "sum" as default mode for gauges with special handling for lora_info metric

I don't see this anywhere in the changes?

Comment on lines +371 to +375
def _create_counter(self, name: str, documentation: Optional[str],
labelnames: list[str]):
return prometheus_client.Counter(name=name,
documentation=documentation,
labelnames=labelnames)
Owner Author

What's the purpose of adding this indirection if all of the args are just passed to the corresponding constructors?


Mostly modularity. We can extend these, for example in this PR vllm-project#17925 we want to wrap these primitives with their Ray equivalent.

Owner Author

@njhill njhill May 14, 2025

It might make more sense to make this particular change in that other PR then since it's not directly related to this one. cc @markmc

Comment on lines +201 to +213
# Upon shutdown of this process, we should mark the process as dead
# See https://prometheus.github.io/client_python/multiprocess/
try:
    import os

    from prometheus_client import multiprocess

    multiprocess.mark_process_dead(os.getpid())
    logger.debug("Marked Prometheus metrics for process %d as dead",
                 os.getpid())
except Exception as e:
    logger.error("Error during metrics cleanup: %s", str(e))

Owner Author

Does it matter if we run this even when prometheus logging is disabled?


This part of the shutdown logic runs through and is a no-op when log_stats=False. If we want to be pedantic about something like prometheus_client not being installed on the host, we could gate the logic on log_stats, but I think it's better to keep it like this. The only edge case I can think of is if the prometheus_client package is not installed, in which case the try-except block will just emit a logger.error that won't be raised. So it's fine by default.


I added more comments to clarify the choice.

Comment on lines +170 to +182
assert num_api_servers > 1
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
    # Make TemporaryDirectory for prometheus multiprocessing
    # Note: global TemporaryDirectory will be automatically
    # cleaned up upon exit.
    global prometheus_multiproc_dir
    prometheus_multiproc_dir = tempfile.TemporaryDirectory()
    os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
    logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                   "This directory must be wiped between vLLM runs or "
                   "you will find inaccurate metrics. Unset the variable "
                   "and vLLM will properly handle cleanup.")
Owner Author

We should probably only do this if prometheus logging is enabled .. at least we probably shouldn't if log_stats is False.


Setting the env var even if log_stats is false seems fine? It doesn't hurt. Why do you think we should gate it on log_stats, or to be more precise, on whether the Prometheus logger is used?

Comment on lines +171 to +172
labelnames=labelnames,
multiprocess_mode="all").labels(*labelvalues)
Owner Author

Suggested change
labelnames=labelnames,
multiprocess_mode="all").labels(*labelvalues)
labelnames=labelnames).labels(*labelvalues)

@njhill
Owner Author

njhill commented May 14, 2025

@kouroshHakha in particular I also don't think we should be using the "all" mode which just labels by pid.

@njhill
Owner Author

njhill commented May 14, 2025

We also need to make sure that the PROMETHEUS_MULTIPROC_DIR env var is always propagated properly. It was tricky to get this right before because I think it needs to be set prior to importing prometheus.

The docs recommend setting it externally but we obviously don't want to have to require that.

This environment variable should be set from a start-up shell script, and not directly from Python (otherwise it may not propagate to child processes).
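The propagation concern can be sketched in a few lines: changes to `os.environ` in the parent are inherited by any process spawned afterwards, so the shell-script recommendation can be worked around as long as the assignment happens before children start. The helper name below is hypothetical, not vLLM code:

```python
import os
import subprocess
import sys
import tempfile


def run_child_with_multiproc_dir(code: str) -> str:
    # Hypothetical helper: set the env var in the parent *before* the
    # child starts; os.environ changes are inherited by processes
    # spawned afterwards, which is what vLLM relies on here.
    os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())
    out = subprocess.run([sys.executable, "-c", code],
                         capture_output=True, text=True, check=True)
    return out.stdout.strip()


# the child sees the directory chosen by the parent
child_sees = run_child_with_multiproc_dir(
    "import os; print(os.environ['PROMETHEUS_MULTIPROC_DIR'])")
```

The fragile part the thread worries about is ordering relative to `import prometheus_client`, not relative to child spawning.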

@kouroshHakha

From your PR description: Set "sum" as default mode for gauges with special handling for lora_info metric
I don't see this anywhere in the changes?

I think I changed it to livemostrecent (forgot to update the description)

https://github.com/njhill/vllm/pull/6/files#diff-43531f6ec44e98f78e8f3fd53839a31c4fe4dd1c1b9015a45bf88b6f5dfeaabdR358

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
@kouroshHakha

@njhill

Comparing -asc=1 and -asc=16 and a counter metric like num_prompt_tokens over time on a fixed workload that has ~2M input tokens.

[Screenshot: 2025-05-14 at 8:47 PM]

@kouroshHakha

kouroshHakha commented May 15, 2025

These we may have to look at more closely and on a case-by-case basis since for example some are histograms where we assume the count corresponds to the number of iterations that have run on the corresponding engine, and we'll now be recording multiple of these (could be between 1 and num api servers).

@njhill I have thought about this for almost the entire day, looking at the metrics and how they show up in Grafana, etc. My conclusion is that we have only one metric that is impacted by recording multiple IterationStats from different api_servers: the vllm:iteration_tokens_total histogram. The main insight is that metrics which are by definition invariant to the notion of an iteration step remain intact (e.g. time metrics, request-level metrics, etc.), but iteration_tokens_total does not fall into this category.

[Screenshot: 2025-05-14 at 9:38 PM]

You can also observe the diff on vllm:iteration_tokens_total on the same workload. I could solve this by computing the n_tokens_per_iteration in each engine and attaching it to the EngineCoreOutputs of only one of the clients. On the client side we would then update iterationstat.n_total_token_this_iteration only if it's returned by the EngineCoreOutputs. I am not sure how easy that is to do in the scheduler, though.

@markmc

markmc commented May 15, 2025

Nice work thinking this through @kouroshHakha

Let's work through a simple example.

With a single API server and single engine

req1: prompt_len=100
req2: prompt_len=50
req3: prompt_len=80

iter1:
  req1: is_prefill=True, new_tokens=1
  req2: is_prefill=True, new_tokens=1
      observe(152)

iter2: 
  req1: is_prefill=False, new_tokens=1
  req2: is_prefill=False, new_tokens=1
  req3: is_prefill=True, new_tokens=1
    observe(83)

iter3:
  req1: is_prefill=False, new_tokens=1
  req2: is_prefill=False, new_tokens=1
  req3: is_prefill=False, new_tokens=1
    observe(3)

iter4:
  req1: is_prefill=False, new_tokens=1
  req2: is_prefill=False, new_tokens=1
  req3: is_prefill=False, new_tokens=1
    observe(3)

result:
  count = 4
  sum = 241
  buckets:
    le=8: 2
    le=128: 3
    le=256: 4

With 2 API servers and 2 engines

req1: prompt_len=100  ===> routed to API1, Engine1
req2: prompt_len=50  ===> routed to API2, Engine2
req3: prompt_len=80  ===> routed to API1, Engine1

API1, Engine1 iter1:
  req1: is_prefill=True, new_tokens=1
      observe(101)
      
API2, Engine2 iter1:
  req2: is_prefill=True, new_tokens=1
      observe(51)

API1, Engine1 iter2: 
  req1: is_prefill=False, new_tokens=1
  req3: is_prefill=True, new_tokens=1
    observe(82)

API2, Engine2 iter2: 
  req2: is_prefill=False, new_tokens=1
      observe(1)

API1, Engine1 iter3:
  req1: is_prefill=False, new_tokens=1
  req3: is_prefill=False, new_tokens=1
    observe(2)

API2, Engine2 iter3:
  req2: is_prefill=False, new_tokens=1
    observe(1)
    
 API1, Engine1 iter4:
  req1: is_prefill=False, new_tokens=1
  req3: is_prefill=False, new_tokens=1
    observe(2)

API2, Engine2 iter4:
  req2: is_prefill=False, new_tokens=1
    observe(1)

result:
  count = 8
  sum = 241
  buckets:
    le=3: 5
    le=8: 5
    le=64: 6
    le=128: 8

So the view you get is that the same number of tokens is being generated with more, smaller iterations?

Is that going to be a problem? Surely people are watching trends, or comparing across like-for-like instances, etc. rather than relying on the actual values?

e.g. sure, you'd see a drop if you rolled out a multi-api-server change ... but that might even be reassuring, and make a ton of sense?

wdyt?

@markmc

markmc commented May 15, 2025

On the code ... I absolutely detest all this "make sure PROMETHEUS_MULTIPROC_DIR env var is set before importing prometheus_client" nonsense!

Firstly, I'm skeptical that the env var needs to be set before importing prometheus_client? Yes, we need to set it before creating the first metric, but why before importing? Especially since we're not using prometheus_client.REGISTRY?

Maybe I'm missing something there, but I'd like to be really sure we have to lazy import ... that's always going to be super brittle

Secondly, if we can put all the prometheus multiproc nonsense in one place - e.g. vllm.v1.metrics.prometheus - then it becomes much more maintainable. In V0, it was sprinkled over a bunch of places. I'd prefer to see no mention of PROMETHEUS_MULTIPROC_DIR anywhere except in vllm.v1.metrics.prometheus

Does that make sense?

@njhill
Owner Author

njhill commented May 15, 2025

Thanks @kouroshHakha @markmc for that careful analysis!

I can solve this by computing the n_tokens_per_iteration in each engine and attach it to the EngineCoreOutputs of only one of the clients. On the client side we can update the iterationstat.n_total_token_this_iteration only if it's returned by the EngineCoreOutputs. I am not sure if this is easy to do in the scheduler actually

I'm not sure we want to complicate the scheduler or introduce more work there... we should aim to avoid that if possible.

Is that going to be a problem? Surely people are watching trends, or comparing across like-for-like instances, etc. rather than relying on the actual values?

I'm not sure, it's possible the rate of the "count" of this histogram could be used to track iteration frequency and used in various other derived "per iteration" metrics. Unfortunately the count also won't be a constant multiple of the "original" count since the number of times it's recorded per engine per iteration would vary depending on how the requests are distributed between api servers / engines.

If this is the only problematic one though it seems reasonable to not block the PR and just document this along with the multi-api-server option.

Firstly, I'm skeptical that the env var needs to be set before importing prometheus_client? Yes, we need to set it before creating the first metric, but why before importing? Especially since we're not using prometheus_client.REGISTRY?

I'm not sure exactly, and agree that's horrible, it's just a (possibly incorrect) recollection from when we were wrangling with this some time back with V0. Hopefully it's not really the case

Secondly, if we can put all the prometheus multiproc nonsense in one place - e.g. vllm.v1.metrics.prometheus - then it becomes much more maintainable. In V0, it was sprinkled over a bunch of places. I'd prefer to see no mention of PROMETHEUS_MULTIPROC_DIR anywhere except in vllm.v1.metrics.prometheus

Good point and I very much agree with this! If we do end up having an import ordering issue it should make that easier to manage too.

@kouroshHakha

kouroshHakha commented May 15, 2025

Hey @markmc,

So the view you get is that the same number of tokens is being generated with more, smaller iterations?

Yep, exactly. I liked your toy example, it's really illuminating. Let's consider a scenario where the engine is shared between two API servers. I think the conclusion is the same, though:

Two API servers but one engine (all-to-all)

req1: prompt_len=100  ===> routed to API1, Engine1
req2: prompt_len=50  ===> routed to API2, Engine1
req3: prompt_len=80  ===> routed to API1, Engine1


Engine1 iter1:
    req1 (api1): is_prefill=True, new_tokens=1
    req2 (api2): is_prefill=True, new_tokens=1
    API1: observe(101)
    API2: observe(51)

Engine1 iter2:
    req1 (api1): is_prefill=False, new_tokens=1
    req2 (api2): is_prefill=False, new_tokens=1
    req3 (api1): is_prefill=True, new_tokens=1
    API1: observe(82)
    API2: observe(1)

Engine1 iter3:
    req1 (api1): is_prefill=False, new_tokens=1
    req2 (api2): is_prefill=False, new_tokens=1
    req3 (api1): is_prefill=False, new_tokens=1
    API1: observe(2)
    API2: observe(1)

Engine1 iter4:
    req1 (api1): is_prefill=False, new_tokens=1
    req2 (api2): is_prefill=False, new_tokens=1
    req3 (api1): is_prefill=False, new_tokens=1
    API1: observe(2)
    API2: observe(1)



result:
  count = 8
  sum = 241
  buckets:
    le=1: 3
    le=8: 5
    le=16: 5
    le=32: 5
    le=64: 6
    le=128: 8


vs. with only one api server

result:
  count = 4
  sum = 241
  buckets:
    le=1: 0
    le=8: 2
    le=16: 2
    le=32: 2
    le=64: 2
    le=128: 3

Whether this is a problem or not goes back to what we want from the system. I think keeping it as is is a reasonable design choice: as you scale the API servers, the total sum remains the same, but it's broken down into smaller steps, which changes the histogram. If we go this route, though, the metric name becomes a bit inconsistent with what it represents. iteration_tokens_total suggests that, as long as the number of engines remains the same, I should see a similar distribution of tokens processed per step by each engine. What @njhill suggested could also be a problem, though. It depends on how much the end user relies on these metrics :)

But I do agree with you that this is certainly not a big problem. We should certainly not block this PR for this @njhill.

Secondly, if we can put all the prometheus multiproc nonsense in one place - e.g. vllm.v1.metrics.prometheus - then it becomes much more maintainable. In V0, it was sprinkled over a bunch of places. I'd prefer to see no mention of PROMETHEUS_MULTIPROC_DIR anywhere except in vllm.v1.metrics.prometheus

@njhill @markmc I also think this was actually not necessary. Once I added the sys.modules check during development, I noticed I was failing that check but my metrics were still correct. So I don't think having to define the env var before an import is an absolute requirement. It was based on past comments on the V0 code path plus the note in the prometheus docs https://prometheus.github.io/client_python/multiprocess/ (This environment variable should be set from a start-up shell script, and not directly from Python (otherwise it may not propagate to child processes).)

I will double check this, and if the lazy import doesn't end up being a requirement I'll remove it as suggested. This will also let us self-contain the Prometheus parts in one place for easier maintenance.

Owner Author

@njhill njhill left a comment

Thanks @kouroshHakha

As well as the inline comments, it would be good to make the change suggested by @markmc to move all the prometheus-touching logic into the prometheus module, calling utility methods from there as needed.

metrics_info["engine"] = self.engine_index

name, documentation = None, None
multiprocess_mode = "mostrecent"
Owner Author

What's the reason for this variable?


Waiting for a green light from you to refactor this function entirely. It's a bit weird that it's doing some conditioning when the condition is always true when I search across the project globally.

Owner Author

I would say just leave it at least for this PR? Could always look into it some more and open a separate PR to refactor...

Comment on lines +475 to +476
def build_buckets(mantissa_lst: list[int],
                  max_value: int) -> list[Union[int, float]]:
Owner Author

What's the reason for these changes? I'm not sure what's wrong with list[int] and regardless I think it's unrelated to the PR purpose?


reverting. linter artifact when I had the indirection

self.labelname_running_lora_adapters,
])
],
multiprocess_mode="livemostrecent"
Owner Author

Is this lora_info gauge another potentially problematic one? Since it's updated via IterationStats and not SchedulerStats ... So I'm not sure "livemostrecent" is the right thing to use here.

I haven't looked closely at what this metric / how it is computed, maybe "sum" would fit better. I have a feeling though even that might not be correct, because it may be counting e.g. the number of unique lora adapters across the running requests and so it's not really possible to just combine the separate counts when the requests from a given engine are partitioned.

If there's not an easy answer we could include this in the list that we document as not being correct when mutli-api-servers are in play.


In v0 world, this was livemostrecent. I assumed it was intentional, so I kept it.
https://github.com/vllm-project/vllm/blob/main/vllm/engine/metrics.py#L80

It is potentially one of those problematic ones. But supporting lora metrics falls even lower in priority than the other one, so we may as well just record it in the docs. I'd keep it as is since it's coming from historical context anyway.

Owner Author

In V0 I don't think we were ever logging the same metrics in multiple places, so the aggregation mode was probably irrelevant. I do think "sum" is probably slightly less bad here. But probably we should just disable this metric when there are multiple API servers, since I think the values will just be wrong (and document that, of course).


I made it sum. We can simply document for now.

… module

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
Signed-off-by: kouroshhakha <kourosh@anyscale.com>
@kouroshHakha

ok @njhill @markmc I separated out the prometheus nonsense into its own Python module under v1.metrics.prometheus. I also tested some of the metrics in single- vs. multi-process scenarios that I found brittle before, and can confirm that the eager import of prometheus_client, the way it is done in this PR, works fine.

We can follow up with moving PrometheusStatLogger to v1.metrics.prometheus later, but I am avoiding that in this PR to keep the diff small.

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
Owner Author

@njhill njhill left a comment

Thanks @kouroshHakha

return REGISTRY


def mount_metrics(app):
Owner Author

Suggested change
def mount_metrics(app):
def mount_metrics(app: FastAPI):

Owner Author

I think it might be better to leave this method in api_server.py, and just call get_prometheus_registry() from here


So this method is very prometheus-heavy (meaning it's not just the registry but also other things like make_asgi_app, etc. that come from prometheus). This really belongs in the prometheus module of vllm. WDYT?

app.routes.append(metrics_route)


def mark_process_dead(pid):
Owner Author

Suggested change
def mark_process_dead(pid):
def mark_process_dead(pid: int):

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
return registry
else:
Owner Author

nit: redundant else

Comment on lines +203 to +208
try:
    mark_process_dead(os.getpid())
    logger.debug("Marked Prometheus metrics for process %d as dead",
                 os.getpid())
except Exception as e:
    logger.error("Error during metrics cleanup: %s", str(e))
Owner Author

put the try/except inside the method too?

@kouroshHakha kouroshHakha May 16, 2025

I think the caller should decide what they want to do. Sg?


Agree with Nick. There's nothing the caller can do, and wrapping anything in such a broad try/except implies some knowledge of what the function is doing. I'd move the os.getpid() into the prometheus module too and call it shutdown_prometheus() or similar
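A hedged sketch of what the suggested `shutdown_prometheus()` could look like (the name and shape follow the comment above, not the merged code), with both the try/except and `os.getpid()` moved inside the prometheus module:

```python
import logging
import os

from prometheus_client import multiprocess

logger = logging.getLogger(__name__)


def shutdown_prometheus() -> None:
    """Mark this process dead for prometheus multiprocess mode.

    Errors are swallowed: there is nothing a caller can usefully do
    about them at shutdown time (e.g. PROMETHEUS_MULTIPROC_DIR unset,
    or multiprocess mode never enabled).
    """
    try:
        pid = os.getpid()
        multiprocess.mark_process_dead(pid)
        logger.debug("Marked Prometheus metrics for process %d as dead", pid)
    except Exception as e:
        logger.error("Error during metrics cleanup: %s", str(e))
```

Callers then invoke a single function with no prometheus knowledge of their own.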

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)

Everything above should be in the prometheus module, returning a registry


# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)

And all of the rest of it should remain in the API server module, using registry returned from the prometheus module


After doing this, it looks nicer, I have to admit :D

# Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names):
    if hasattr(collector, "_name") and "vllm" in collector._name:
        REGISTRY.unregister(collector)

One nice thing about having all this nonsense together in one module, you wonder ...

If we're using REGISTRY here, we're assuming multiprocess mode is not enabled? Maybe assert that the env var is not set?

Owner Author

Good catch ... the logic here should differ in the multiproc case, I think?


good catch :)

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
Signed-off-by: kouroshhakha <kourosh@anyscale.com>
@kouroshHakha

@njhill @markmc ready, incorporated your feedback.

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
@markmc markmc left a comment

Super minor comments

from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator

remove?

@kouroshHakha kouroshHakha May 16, 2025

can't remove. mount_metrics is added back here, so it's needed. That's why I wanted to keep mount_metrics entirely in prometheus.


Doh, I misread, sorry. That's fine, this stuff is in the "api server" category not the "disgusting prometheus multi proc hackery" category 😃

"""Mark a process as dead in prometheus multiprocessing.

Args:
pid: Process ID to mark as dead

remove


done already.

self.gauge_scheduler_running = prometheus_client.Gauge(
    name="vllm:num_requests_running",
    documentation="Number of requests in model execution batches.",
    labelnames=labelnames).labels(*labelvalues)

If you're re-spinning again, then a very minor stylistic request ...

In the original version this line is all "label stuff"

In the new version, it becomes "label stuff, new line, multiproc stuff, label stuff"


Sorry, I did not quite understand the desired style?


oh you mean the order? put multiprocess_mode before label stuff?


ok done.

documentation="Number of requests in model execution batches.",
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames,
multiprocess_mode="mostrecent").labels(*labelvalues)

Suggest (for all these multiprocess_mode changes)

            multiprocess_mode="mostrecent",
            labelnames=labelnames).labels(*labelvalues)

Owner Author

@njhill njhill left a comment

Couple more small comments since @markmc also had some :)

Could you also add a similar warning comment to the iteration_tokens_total metric?

self.labelname_running_lora_adapters,
])
],
multiprocess_mode="sum"
Owner Author

Could you add a comment here (maybe above, as part of the "LoRA metrics" comment), explaining that this metric will not be correct when using api-server scaleout, which uses prometheus mp mode.

Signed-off-by: kouroshhakha <kourosh@anyscale.com>
Signed-off-by: kouroshhakha <kourosh@anyscale.com>
@njhill
Owner Author

njhill commented May 16, 2025

Thanks for all of your help and patience with this @kouroshHakha!

@njhill njhill merged commit 1bf3a63 into njhill:all-to-all May 16, 2025